"""
This script trains
- a policy network
- and a detector model that maps raw observations to log-probabilities

Test with
>>> python train_agent2.py --algo ppo --env Adversarial-v0

Limitations
- a2c not supported
- memory not supported
- eval mode removed
- acmodel and detectormodel share optimizer configs
"""

import argparse
import time
import datetime
import torch
import torch_ac
import tensorboardX
import sys
import wandb
import os
import glob

import utils
from model import ACModel
from recurrent_model import RecurrentACModel
from detector_model import PerfectDetector, SimpleDetectorModel, RecurrentDectectorModel

# Parse arguments

parser = argparse.ArgumentParser()

## General parameters
# Only PPO is implemented (torch_ac.PPOAlgo is always used below), so reject
# anything else instead of silently ignoring the flag.
parser.add_argument("--algo", required=True, choices=["ppo"],
                    help="algorithm to use; only ppo is supported (REQUIRED)")
parser.add_argument("--env", required=True,
                    help="name of the environment to train on (REQUIRED)")
parser.add_argument("--seed", type=int, default=1,
                    help="random seed (default: 1)")
parser.add_argument("--log-interval", type=int, default=10,
                    help="number of updates between two logs (default: 10)")
parser.add_argument("--save-interval", type=int, default=100,
                    help="number of updates between two saves (default: 100, 0 means no saving)")
parser.add_argument("--procs", type=int, default=16,
                    help="number of processes (default: 16)")
parser.add_argument("--frames", type=int, default=2*10**8,
                    help="number of frames of training (default: 2*10^8)")
parser.add_argument("--checkpoint-dir", default=None,
                    help="root directory for run outputs (default: storage)")
parser.add_argument("--load-model", default=None,
                    help="Directory of the model to load")
parser.add_argument("--wandb", action="store_true", default=False,
                    help="Log the experiment with weights & biases")

## Parameters for PPO algorithm
parser.add_argument("--epochs", type=int, default=8,
                    help="number of epochs for PPO (default: 8)")
parser.add_argument("--batch-size", type=int, default=256,
                    help="batch size for PPO (default: 256)")
parser.add_argument("--frames-per-proc", type=int, default=1024,
                    help="number of frames per process before update (default: 1024)")
parser.add_argument("--discount", type=float, default=0.99,
                    help="discount factor (default: 0.99)")
parser.add_argument("--lr", type=float, default=0.0003,
                    help="learning rate (default: 0.0003)")
parser.add_argument("--gae-lambda", type=float, default=0.95,
                    help="lambda coefficient in GAE formula (default: 0.95, 1 means no gae)")
parser.add_argument("--entropy-coef", type=float, default=0.01,
                    help="entropy term coefficient (default: 0.01)")
parser.add_argument("--value-loss-coef", type=float, default=0.5,
                    help="value loss term coefficient (default: 0.5)")
parser.add_argument("--max-grad-norm", type=float, default=0.5,
                    help="maximum norm of gradient (default: 0.5)")
parser.add_argument("--optim-eps", type=float, default=1e-8,
                    help="Adam and RMSprop optimizer epsilon (default: 1e-8)")
parser.add_argument("--optim-alpha", type=float, default=0.99,
                    help="RMSprop optimizer alpha (default: 0.99)")
parser.add_argument("--clip-eps", type=float, default=0.2,
                    help="clipping epsilon for PPO (default: 0.2)")
parser.add_argument("--recurrence", type=int, default=1,
                    help="number of time-steps gradient is backpropagated (default: 1). If > 1, a LSTM is added to the model to have memory.")

## Parameters for training the detector
parser.add_argument("--detector-epochs", type=int, default=8,
                    help="number of epochs for detector (default: 8)")
parser.add_argument("--detector-batch-size", type=int, default=256,
                    help="batch size for detector (default: 256)")
parser.add_argument("--detector-lr", type=float, default=0.0003,
                    help="learning rate for detector (default: 0.0003)")
parser.add_argument("--detector-recurrence", type=int, default=4,
                    help="number of time-steps gradient is backpropagated (default: 4). If > 1, a LSTM is added to the detector model to have memory.")
parser.add_argument("--dumb-ac", action="store_true", default=False,
                    help="Use a single-layer actor-critic")
parser.add_argument("--no-rm", action="store_true", default=False,
                    help="The agent is ignorant of any RM states.")
# Restrict to the values the detector construction below understands, so a typo
# fails here instead of after the environments and wandb run have been created.
parser.add_argument("--rm-update-algo", type=str, default="rm_detector",
                    choices=["rm_detector", "rm_threshold", "event_threshold",
                             "independent_belief", "perfect_rm"],
                    help="[rm_detector, rm_threshold, event_threshold, independent_belief, perfect_rm] (default: rm_detector)")

# TODO: combine --no-rm as an option in --rm-update-algo
args = parser.parse_args()

# Recurrent (memory) variants are used only when more than one step is backpropagated.
use_mem = args.recurrence > 1
use_mem_detector = args.detector_recurrence > 1

# Set run dir

# The run name has no timestamp: re-running the same configuration resolves to
# the same model_dir, which the status loading below tries to resume from.
name = "no-rm" if args.no_rm else args.rm_update_algo
default_model_name = f"{name}-{args.env}-seed{args.seed}"
model_name = default_model_name
storage_dir = "storage" if args.checkpoint_dir is None else args.checkpoint_dir
model_dir = utils.get_model_dir(model_name, storage_dir)
# --load-model is resolved with an empty storage root, i.e. not under storage_dir.
in_model_dir = None if args.load_model is None else utils.get_model_dir(args.load_model, "")

# Load loggers and experiment tracker

txt_logger = utils.get_txt_logger(model_dir + "/train")
csv_file, csv_logger = utils.get_csv_logger(model_dir + "/train")

# Without --wandb, switch W&B off through its environment variable so the
# wandb.init / wandb.log calls below do not send anything.
if not args.wandb:
    os.environ['WANDB_MODE'] = 'disabled'

# Give the run its name and config at init time; renaming it afterwards and
# calling wandb.run.save() with no arguments is a deprecated pattern.
wandb.init(project='noisy-detector', name=default_model_name, config=vars(args))

utils.save_config(model_dir + "/train", args)

# Log command and all script arguments

txt_logger.info("{}\n".format(" ".join(sys.argv)))
txt_logger.info("{}\n".format(args))

# Seed every randomness source, then choose the compute device
# (first CUDA device when available, CPU otherwise).

utils.seed(args.seed)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
txt_logger.info(f"Device: {device}\n")

# Load environments

# One environment per process, each with its own seed. Passing args.seed to
# every copy would give all copies the same random stream, assuming
# utils.make_env seeds the env with this value (TODO confirm). The 10000 stride
# follows the torch-ac / rl-starter-files convention, so seeds of different
# --seed runs do not overlap for any realistic --procs.
envs = [utils.make_env(args.env, args.rm_update_algo, args.seed + 10000 * i)
        for i in range(args.procs)]

txt_logger.info("Environments loaded\n")

# Load training status.
# An explicit --load-model checkpoint takes precedence. Otherwise resume from
# this run's own directory, or start from scratch if nothing was saved there.
if in_model_dir:
    status = utils.get_status(in_model_dir)
    txt_logger.info(f"Training status loaded from {in_model_dir}.\n")
else:
    try:
        status = utils.get_status(model_dir + "/train")
    except OSError:
        status = {"num_frames": 0, "update": 0}
        txt_logger.info("No saved training status found; starting from scratch.\n")
    else:
        txt_logger.info("Training status loaded.\n")


# Load observations preprocessor
obs_space, preprocess_obss = utils.get_obss_preprocessor(envs[0])
# Not every preprocessor exposes a vocab, so probe with getattr. This matches
# the hasattr guard used when the vocab is saved.
if "vocab" in status and getattr(preprocess_obss, "vocab", None) is not None:
    preprocess_obss.vocab.load_vocab(status["vocab"])
txt_logger.info("Observations preprocessor loaded.\n")


# Load model

# Detector: the recurrent variant when --detector-recurrence > 1, otherwise the
# feed-forward one. Its second constructor argument is obs_space['events'] or
# obs_space['rm_state'], depending on --rm-update-algo (presumably the space the
# detector predicts over). "perfect_rm" always uses PerfectDetector.
detector_cls = RecurrentDectectorModel if use_mem_detector else SimpleDetectorModel
rm_algo = args.rm_update_algo
if rm_algo == "perfect_rm":
    detectormodel = PerfectDetector(obs_space)
elif rm_algo in ("event_threshold", "independent_belief"):
    detectormodel = detector_cls(obs_space, obs_space['events'])
elif rm_algo in ("rm_detector", "rm_threshold"):
    detectormodel = detector_cls(obs_space, obs_space['rm_state'])
else:
    raise NotImplementedError()

# Actor-critic: the recurrent variant when --recurrence > 1.
ac_cls = RecurrentACModel if use_mem else ACModel
acmodel = ac_cls(envs[0], obs_space, envs[0].action_space, args.dumb_ac, args.no_rm)

# Restore saved weights when the status carries them (actor-critic first, then
# the detector), then move both models to the chosen device.
for state_key, net, label in (("model_state", acmodel, "acmodel"),
                              ("detector_model_state", detectormodel, "detector model")):
    if state_key in status:
        net.load_state_dict(status[state_key])
        txt_logger.info(f"Loading {label} from existing run.\n")

acmodel.to(device)
txt_logger.info("AC Model loaded.\n")
txt_logger.info(f"{acmodel}\n")

detectormodel.to(device)
txt_logger.info("Detector Model loaded.\n")
txt_logger.info(f"{detectormodel}\n")

# Load algo
# TODO: customize detector learning config
# Arguments are positional, so their order must match this project's
# torch_ac.PPOAlgo signature. It takes a detector model, so it is presumably a
# modified torch_ac — confirm against that package. The single --optim-eps
# value is shared by both models.
algo = torch_ac.PPOAlgo(envs, acmodel, detectormodel, device, args.frames_per_proc, args.discount, args.lr, args.gae_lambda,
                        args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
                        args.detector_epochs, args.detector_batch_size, args.detector_lr, args.detector_recurrence,
                        args.optim_eps, args.clip_eps, args.epochs, args.batch_size, preprocess_obss, rm_update_algo=args.rm_update_algo)

if "optimizer_state" in status:
    algo.optimizer.load_state_dict(status["optimizer_state"])
    txt_logger.info("Loading optimizer from existing run.\n")

# algo.optimizer_detector may be None (the checkpoint code only saves it when it
# is not None). In that case skip the restore instead of crashing.
if "optimizer_detector_state" in status:
    if algo.optimizer_detector is not None:
        algo.optimizer_detector.load_state_dict(status["optimizer_detector_state"])
        txt_logger.info("Loading detector optimizer from existing run.\n")
    else:
        txt_logger.info("Saved detector optimizer state ignored: this run has no detector optimizer.\n")

txt_logger.info("Optimizer loaded.\n")


# Train model

def _save_training_status(num_frames, update):
    """Checkpoint counters, model/optimizer weights and (if any) the vocab.

    Writes to ``model_dir + "/train"`` via ``utils.save_status`` and returns the
    saved status dict.
    """
    checkpoint = {"num_frames": num_frames, "update": update,
                  "model_state": algo.acmodel.state_dict(),
                  "detector_model_state": algo.detectormodel.state_dict(),
                  "optimizer_state": algo.optimizer.state_dict()}
    if algo.optimizer_detector is not None:
        checkpoint["optimizer_detector_state"] = algo.optimizer_detector.state_dict()
    if hasattr(preprocess_obss, "vocab") and preprocess_obss.vocab is not None:
        checkpoint["vocab"] = preprocess_obss.vocab.vocab
    utils.save_status(checkpoint, model_dir + "/train")
    txt_logger.info("Status saved")
    return checkpoint


num_frames = status["num_frames"]
update = status["update"]
start_time = time.time()

# Write the CSV header exactly once: for a fresh run, or when starting from a
# --load-model checkpoint (whose counters are non-zero but whose CSV is new).
# The earlier check, status["num_frames"] == 0, stayed true until the first
# save, so the header was repeated on every log before the first checkpoint.
write_csv_header = status["num_frames"] == 0 or in_model_dir is not None
last_saved_update = update

while num_frames < args.frames:

    # Update model parameters: collect a rollout, then update the actor-critic
    # and the detector on the same experiences.

    update_start_time = time.time()
    exps, logs1 = algo.collect_experiences()
    logs2 = algo.update_ac_parameters(exps)
    logs3 = algo.update_detector_parameters(exps)
    logs = {**logs1, **logs2, **logs3}
    update_end_time = time.time()

    num_frames += logs["num_frames"]
    update += 1

    # Print logs

    if update % args.log_interval == 0:
        # FPS covers only the latest update, not the whole run.
        fps = logs["num_frames"]/(update_end_time - update_start_time)
        duration = int(time.time() - start_time)

        return_per_episode = utils.synthesize(logs["return_per_episode"])
        rreturn_per_episode = utils.synthesize(logs["reshaped_return_per_episode"])
        average_reward_per_step = utils.average_reward_per_step(logs["return_per_episode"], logs["num_frames_per_episode"])
        average_discounted_return = utils.average_discounted_return(logs["return_per_episode"], logs["num_frames_per_episode"], args.discount)
        num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])

        header = ["update", "frames", "FPS", "duration"]
        data = [update, num_frames, fps, duration]
        header += ["rreturn_" + key for key in rreturn_per_episode.keys()]
        data += rreturn_per_episode.values()
        header += ["average_reward_per_step", "average_discounted_return"]
        data += [average_reward_per_step, average_discounted_return]
        header += ["num_frames_" + key for key in num_frames_per_episode.keys()]
        data += num_frames_per_episode.values()
        header += ["entropy", "value", "policy_loss", "value_loss", "grad_norm"]
        data += [logs["entropy"], logs["value"], logs["policy_loss"], logs["value_loss"], logs["grad_norm"]]
        header += ["detector_loss", "detector_grad_norm", "detector_top1_accuracy"]
        data += [logs["detector_loss"], logs["detector_grad_norm"], logs["detector_top1_accuracy"]]

        # The format string consumes the fields collected so far, in order.
        txt_logger.info(
            "U {} | F {:06} | FPS {:04.0f} | D {} | rR:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | ARPS: {:.3f} | ADR: {:.3f} | F:μσmM {:.1f} {:.1f} {} {} | H {:.3f} | V {:.3f} | pL {:.3f} | vL {:.3f} | ∇ {:.3f} | dL {:.3f} | d∇ {:.3f} | dtop1Acc {:.3f}"
            .format(*data))

        # Raw returns go to the CSV / W&B only, not to the text log.
        header += ["return_" + key for key in return_per_episode.keys()]
        data += return_per_episode.values()

        if write_csv_header:
            csv_logger.writerow(header)
            write_csv_header = False
        csv_logger.writerow(data)
        csv_file.flush()

        # Log all metrics in a single call. Each wandb.log call advances the
        # W&B step, so logging field by field scattered one update's metrics
        # across many steps.
        wandb.log(dict(zip(header, data)))

    # Save status

    if args.save_interval > 0 and update % args.save_interval == 0:
        status = _save_training_status(num_frames, update)
        last_saved_update = update

# Checkpoint the final state as well, unless the last update was already saved
# (or no update ran at all).
if args.save_interval > 0 and update != last_saved_update:
    status = _save_training_status(num_frames, update)

